def jax_generate_cov(zs1, zs2, a, ell):
    """Evaluate a squared-exponential kernel between two sets of points.

    Entry ``[i, j]`` of the result equals
    ``a**2 * exp(-(zs1[i] - zs2[j])**2 / ell**2)``.

    Args:
        zs1: Points indexing the rows of the result (assumed 1-D; a trailing
            axis is appended for broadcasting).
        zs2: Points indexing the columns of the result (assumed 1-D; a
            leading axis is prepended for broadcasting).
        a: Amplitude; it enters the kernel squared.
        ell: Length scale; it enters squared, with no extra factor of 2 in
            the exponent.

    Returns:
        The matrix of pairwise kernel values.
    """
    pairwise_diff = zs1[:, None] - zs2[None, :]
    scaled_sq_dist = pairwise_diff**2 / ell**2
    return (a**2) * jnp.exp(-scaled_sq_dist)
def jax_gaussian_process(x0_mean, x1_mean, x1, cov_blocks, tau):
    """Condition the Gaussian block x0 on a noisy observation of block x1.

    Args:
        x0_mean: Prior mean of the block being predicted.
        x1_mean: Prior mean of the observed block; its leading dimension
            sets the size of the noise term.
        x1: Observed values for the second block.
        cov_blocks: Nested 2x2 sequence of covariance blocks. Only
            ``[0][0]``, ``[0][1]`` and ``[1][1]`` are read; the transpose of
            ``[0][1]`` stands in for ``[1][0]``.
        tau: Scalar multiplied by the identity and added to the observed
            block's covariance before solving.

    Returns:
        Tuple ``(conditional_mean, conditional_covariance)`` for x0.
    """
    cross_cov = cov_blocks[0][1]
    n_obs = x1_mean.shape[0]
    noisy_obs_cov = cov_blocks[1][1] + tau * jnp.eye(n_obs)

    # Both terms are solves against the same noisy observation covariance.
    residual = x1 - x1_mean
    mean_shift = cross_cov.dot(jax.numpy.linalg.solve(noisy_obs_cov, residual))
    cov_reduction = cross_cov.dot(
        jax.numpy.linalg.solve(noisy_obs_cov, cross_cov.T)
    )

    return x0_mean + mean_shift, cov_blocks[0][0] - cov_reduction
def jax_make_prediction(z, predicted_mean, gen_cov, obs, obs_means, obs_zs, obs_cov, tau):
    """Predict the value at a single point ``z`` from observations.

    Builds the 2x2 block covariance of (value at ``z``, observations) and
    conditions on the observations via ``jax_gaussian_process``.

    Args:
        z: Scalar location of the prediction; it is wrapped into a
            length-1 array before calling ``gen_cov``.
        predicted_mean: Prior mean of the value at ``z``.
        gen_cov: Callable ``gen_cov(zs1, zs2)`` returning the covariance
            matrix between two sets of locations.
        obs: Observed values.
        obs_means: Prior means of the observations.
        obs_zs: Locations of the observations.
        obs_cov: Covariance among the observations (before ``tau`` is added).
        tau: Noise term forwarded to ``jax_gaussian_process``.

    Returns:
        Tuple ``(mean, variance)``: element ``[0]`` of the conditional mean
        and element ``[0][0]`` of the conditional covariance.
    """
    z = jnp.array([z])
    cross_cov = gen_cov(z, obs_zs)
    cov_blocks = [
        [gen_cov(z, z), cross_cov],
        # The lower-left block is the transpose of the upper-right one, so
        # the joint covariance has a consistent (n, 1) shape in that slot.
        [cross_cov.T, obs_cov],
    ]
    pred_mean_cond, pred_var_cond = jax_gaussian_process(
        jnp.array([predicted_mean]), obs_means, obs, cov_blocks, tau
    )
    return pred_mean_cond[0], pred_var_cond[0][0]
def jax_v_position(x, grid, g_x_edges, g_y_edges):
    """Look up the grid entries of the cells that contain each position.

    Args:
        x: Positions; column 0 is binned against ``g_x_edges`` and column 1
            against ``g_y_edges``.
        grid: Array indexed on four axes as ``grid[:, ix, iy, :]``.
        g_x_edges: Sorted bin edges along the first spatial axis.
        g_y_edges: Sorted bin edges along the second spatial axis.

    Returns:
        ``grid[:, ix, iy, :]`` with its axes reversed (``.T``), where
        ``ix``/``iy`` are the zero-based bin indices of each position.
    """
    def _bin_index(edges, coords):
        # Limit the searchsorted result to [1, len(edges) - 1] so positions
        # outside the edges fall into the first or last bin, then shift to a
        # zero-based index.
        raw = jnp.searchsorted(edges, coords)
        upper = edges.shape[0] - 1
        return jnp.maximum(1, jnp.minimum(upper, raw)) - 1

    ix = _bin_index(g_x_edges, x[:, 0])
    iy = _bin_index(g_y_edges, x[:, 1])
    return grid[:, ix, iy, :].T
def get_v_func(gen_cov, t_grid, g_x_edges, g_y_edges, grid_cov, tau):
    """Build a function that predicts grid values at a time and positions.

    Args:
        gen_cov: Callable ``gen_cov(zs1, zs2)`` producing covariance matrices.
        t_grid: Locations (times) of the observations along the series axis.
        g_x_edges: Bin edges forwarded to ``jax_v_position``.
        g_y_edges: Bin edges forwarded to ``jax_v_position``.
        grid_cov: Covariance among the observations at ``t_grid``.
        tau: Noise term forwarded to ``jax_make_prediction``.

    Returns:
        ``v_func(t, x, grid)``, which returns the predicted means reshaped
        to ``(2, obs.shape[1])`` and transposed.
    """
    def v_func(t, x, grid):
        obs = jax_v_position(x, grid, g_x_edges, g_y_edges)

        def _predict_series(series):
            # The series' own average is used both as the prior mean of the
            # prediction and as the prior mean of every observation.
            prior_means = jnp.ones_like(series) * series.mean()
            prediction = jax_make_prediction(
                t, series.mean(), gen_cov, series, prior_means,
                t_grid, grid_cov, tau,
            )
            return prediction[0]  # keep the mean, drop the variance

        # Flatten all leading axes so each row is one series, predict per
        # row, then restore the layout. The reshape to two rows assumes the
        # leading axis of ``obs`` has size 2.
        flat_series = obs.reshape((-1, obs.shape[2]))
        flat_means = jax.vmap(_predict_series)(flat_series)
        return flat_means.reshape((2, obs.shape[1])).T

    return v_func
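# Not jitted on its own: v_func is a Python callable, which jax.jit cannot
# take as a traced argument, so static_argnums=(0,1) alone would not be enough.
# The jitted entry point is jax_gp.
#@functools.partial(jax.jit, static_argnums=(0,1))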
def jax_sim_gp(N, save_stride, t_edges, v_func, v_grid, x_0, epsilon):
    """Integrate positions with forward-Euler steps, saving periodically.

    Each step evaluates ``v = v_func(t, x, v_grid)`` and then advances
    ``x += epsilon * v`` and ``t += epsilon``. The state is recorded once per
    chunk of ``save_stride`` steps.

    Args:
        N: Nominal number of steps; one chunk is run for every element of
            ``arange(0, N, save_stride)``. Must be a concrete integer.
        save_stride: Number of steps per chunk (concrete integer). If ``N``
            is not a multiple of it, the last chunk still runs a full
            ``save_stride`` steps.
        t_edges: Sorted time edges used to select a slice of ``v_grid``.
        v_func: Callable ``v_func(t, x, v_grid)`` returning the velocity.
        v_grid: Array whose leading axis is indexed by time bin; each slice
            is split along its last axis into two components (assumes that
            axis has size 2).
        x_0: Initial positions.
        epsilon: Step size for both position and time.

    Returns:
        Tuple ``(x, t, v, grid_0, grid_1)`` with one entry per chunk stacked
        along a new leading axis. ``v`` and the two grid components are the
        ones used in the last step of the chunk, i.e. evaluated before that
        step's update of ``x`` and ``t``.
    """
    def get_grid(t):
        # Limit the searchsorted result to [1, len(t_edges) - 1] so times
        # outside the edges select the first or last bin.
        raw = jnp.searchsorted(t_edges, t)
        bin_idx = jnp.maximum(1, jnp.minimum(t_edges.shape[0] - 1, raw))
        return v_grid[bin_idx - 1]

    state_0 = (
        x_0,
        0.0,
        v_func(0.0, x_0, v_grid),
        *get_grid(0.0).transpose([2, 0, 1]),
    )

    def step(_, state):
        x, t, _, _, _ = state
        v_t_grid = get_grid(t)
        v = v_func(t, x, v_grid)
        x = x + epsilon * v  # next position
        t = t + epsilon  # next time
        return (x, t, v, v_t_grid[:, :, 0], v_t_grid[:, :, 1])

    def save_step(state, _):
        # ``step`` ignores the loop index, so use static bounds rather than
        # the traced chunk offset: a static trip count lets fori_loop lower
        # to scan, which keeps reverse-mode differentiation available.
        advanced = jax.lax.fori_loop(0, save_stride, step, state)
        return advanced, advanced

    chunk_starts = jnp.arange(0, N, save_stride)
    _, states = jax.lax.scan(save_step, state_0, xs=chunk_starts)
    return states
@functools.partial(jax.jit, static_argnums=(0,1))
def jax_gp(N, save_stride, t_edges, g_x_edges, g_y_edges, v_grid, x_0, epsilon, a, ell, tau):
    """Jitted entry point: build the velocity predictor and run the simulation.

    Args:
        N: Nominal number of steps (static under jit).
        save_stride: Steps between saved states (static under jit).
        t_edges: Time bin edges; their midpoints are the observation times.
        g_x_edges: Bin edges forwarded to ``get_v_func``.
        g_y_edges: Bin edges forwarded to ``get_v_func``.
        v_grid: Gridded values forwarded to ``jax_sim_gp``.
        x_0: Initial positions.
        epsilon: Integration step size.
        a: Kernel amplitude passed to ``jax_generate_cov``.
        ell: Kernel length scale passed to ``jax_generate_cov``.
        tau: Noise term forwarded to ``get_v_func``.

    Returns:
        The value returned by ``jax_sim_gp``.
    """
    # Midpoints of consecutive time edges serve as the observation times.
    t_grid = (t_edges[1:] + t_edges[:-1]) / 2

    def gen_cov(z, w):
        return jax_generate_cov(z, w, a, ell)

    grid_cov = gen_cov(t_grid, t_grid)
    v_func = get_v_func(gen_cov, t_grid, g_x_edges, g_y_edges, grid_cov, tau)
    return jax_sim_gp(N, save_stride, t_edges, v_func, v_grid, x_0, epsilon)